{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f08ed48d-0a05-4b3e-bb52-d63600ed6772",
   "metadata": {},
   "source": [
    "# Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "202c9deb-f13f-4ca1-92fc-1922282cf0e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class MLP(torch.nn.Module):\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super().__init__()\n",
    "\n",
    "        self.all_layers = torch.nn.Sequential(\n",
    "            torch.nn.Linear(num_features, 10),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Dropout(0.5),\n",
    "            \n",
    "            # output layer\n",
    "            torch.nn.Linear(10, num_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.all_layers(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1b3c1d6c-e39c-4b56-9e69-7bd55b8d97ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "model = MLP(num_features=5, num_classes=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08f6a192-5212-4d21-9ed5-b83c81c1b609",
   "metadata": {},
   "source": [
    "## Dropout during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d4f54a58-dd64-4cba-9375-7d4698292002",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.tensor([1., 0.3, 2.4, -1.1, -0.8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f104b59c-8f36-43d3-b3e0-3dcd4858c31a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.1564, -0.2977], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e8dea15d-e784-4492-97a9-329d96a4dd1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.1359, 0.0523], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fa5b0d3-b63c-4361-9176-3a3189caa540",
   "metadata": {},
   "source": [
    "## Disable dropout during inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "22705321-a2b0-4a38-adc9-a3c02a714ad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a289774a-b3f6-457b-9ae6-341e6a8081ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0458, -0.1777], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "71adeb90-b0d5-491b-98d2-65801d3e650c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0458, -0.1777], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84bb2fd0-37a2-4ebb-9984-eacdee0c166c",
   "metadata": {},
   "source": [
    "Note: during inference, it's also recommended to use either `torch.no_grad()` or `torch.inference_mode()` context so that gradient tracking is disabled. (Not used above to demonstrate that `.eval()` disables dropout during inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "636ddddb-e198-4ec5-acee-c52bf4f6a6ea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.0458, -0.1777])\n"
     ]
    }
   ],
   "source": [
    "with torch.inference_mode():\n",
    "    print(model(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dc8ffa6-6883-4b74-b0b9-32c98d3cf2fd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
